from transformers import pipeline
import torch 
import torch.nn as nn
import pandas as pd
import numpy as np
import pdb
from tqdm import tqdm
import time

class FudgeWrapper(torch.nn.Module):
    """Wrap a logit-producing layer and re-rank its top-k next tokens.

    On every forward pass the wrapped layer's logits for the last sequence
    position are restricted to the ``k`` best candidates. Each candidate is
    appended to the stored prompt, scored by a text classifier, and its logit
    is shifted by ``log p(non-toxic | prompt + candidate)``. All other
    vocabulary entries at that position are pushed just below the worst
    re-ranked candidate, so the result is only meaningful for greedy
    (argmax) decoding.

    Args:
        lm_head: Module whose output is a ``(batch, seq, vocab)`` logit
            tensor (the rank is fixed by the indexing done in ``forward``).
            It is switched to eval mode.
        experiment: ``'toxicity'`` or ``'sentiment'``; selects the
            discriminator checkpoint and which of its labels count as
            undesirable.
        tokenizer: Object exposing ``batch_decode(ids, skip_special_tokens=...)``.
        device: Stored on the instance for backward compatibility. The tensor
            returned by ``forward`` follows the device of the wrapped layer's
            output instead.
        k: Number of candidate tokens re-scored per sequence.

    Raises:
        ValueError: If ``experiment`` is not one of the supported names.
    """

    # Lower bound applied to discriminator probabilities before taking the
    # log, so that a probability of exactly 0 cannot yield ``-inf``.
    _MIN_PROB = 1e-8

    def __init__(self, lm_head: nn.Module, experiment: str, tokenizer, device='cuda', k=50):
        super().__init__()

        # Explicit raise rather than ``assert``: asserts vanish under ``-O``
        # and any unknown name would then silently load the sentiment model.
        if experiment not in ('toxicity', 'sentiment'):
            raise ValueError(
                f"experiment must be 'toxicity' or 'sentiment', got {experiment!r}"
            )

        self.base_layer = lm_head
        self.base_layer.eval()
        self.device = device
        self.tokenizer = tokenizer

        if experiment == 'toxicity':
            self.discriminator = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-offensive")
            self.toxic_label = 'offensive'
            self.nontoxic_label = ('non-offensive',)
        else:
            self.discriminator = pipeline("text-classification", model='cardiffnlp/twitter-roberta-base-sentiment-latest')
            self.toxic_label = 'negative'
            self.nontoxic_label = ('positive', 'neutral')

        self.k = k
        self.prompts = []   # one text per batch row, read by forward()
        self.latency = []   # wall-clock seconds spent in each forward() call

    def store_batch_prompts(self, prompts):
        """Set the texts that candidate tokens are appended to in ``forward``.

        ``forward`` never updates these itself, so the caller has to refresh
        them whenever the text that should be scored changes. There must be
        at least one prompt per batch row, in batch order.
        """
        self.prompts = prompts

    def _nontoxic_probs(self, texts):
        """Return p(non-toxic) for each text according to the discriminator.

        Only the single label/score pair reported per text is used: a
        non-toxic label contributes its score directly, the toxic label
        contributes ``1 - score``.
        NOTE(review): with more than two classes (the sentiment checkpoint)
        the score of one non-toxic label presumably under-estimates the total
        non-toxic mass -- confirm whether that approximation is acceptable.

        Raises:
            ValueError: If the discriminator reports a label that is neither
                the toxic label nor one of the non-toxic labels.
        """
        probs = []
        for result in self.discriminator(texts):
            label, score = result['label'], result['score']
            if label in self.nontoxic_label:
                probs.append(score)
            elif label == self.toxic_label:
                probs.append(1 - score)
            else:
                # Skipping the entry would leave a ragged list and fail much
                # later with an unrelated-looking tensor construction error.
                raise ValueError(f"Unexpected discriminator label: {label!r}")
        return probs

    def forward(self, x, *args, **kwargs):
        """Compute discriminator-adjusted logits.

        Args:
            x: Input forwarded to the wrapped layer together with ``*args``
                and ``**kwargs``.

        Returns:
            A tensor with the same shape, dtype and device as the wrapped
            layer's output. Positions before the last one are returned
            unchanged. At the last position the top-k entries hold
            ``logit + log p(non-toxic)`` and every other entry holds the
            smallest of those values minus 0.1.

        Raises:
            IndexError: If fewer prompts are stored than there are batch rows.
            ValueError: If the discriminator returns an unknown label.
        """
        start = time.perf_counter()

        x_logits = self.base_layer(x, *args, **kwargs)
        batch_size, vocab_size = x_logits.shape[0], x_logits.shape[-1]

        if len(self.prompts) < batch_size:
            raise IndexError(
                f"{len(self.prompts)} prompt(s) stored for a batch of {batch_size}; "
                "call store_batch_prompts() before forward()"
            )

        # Only the last position decides the next token, and it is the only
        # position the discriminator scores describe.
        k = min(self.k, vocab_size)
        top_k, top_k_indices = torch.topk(x_logits[:, -1, :], k, dim=-1)  # both (batch, k)

        # One decoded string per candidate token id, per batch row.
        decoded_tokens = [
            self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
            for token_ids in top_k_indices
        ]

        # p(non-toxic | prompt + candidate) for every candidate.
        nontoxic_probs = [
            self._nontoxic_probs([self.prompts[i] + ' ' + tok for tok in candidates])
            for i, candidates in enumerate(decoded_tokens)
        ]
        nontoxic_probs = torch.tensor(
            nontoxic_probs, dtype=torch.float32, device=x_logits.device
        )

        # Multiplying probabilities corresponds to adding log-probabilities,
        # and softmax/argmax are invariant to the per-row normaliser, so the
        # log term can be added straight onto the logits. Scaling the raw
        # logits by the probability instead would favour *toxic* candidates
        # whenever the logits are negative. The log is taken in float32 so
        # the clamp cannot underflow for half-precision logits.
        log_probs = torch.log(nontoxic_probs.clamp_min(self._MIN_PROB))
        modified_top_k = top_k + log_probs.to(top_k.dtype)

        # Everything outside the top-k sits just below the worst candidate.
        floor = modified_top_k.min(dim=-1, keepdim=True)[0] - 0.1
        new_last = floor.expand(batch_size, vocab_size).clone()
        # scatter_ writes each row's values at that row's own token ids, which
        # keeps batches larger than one correct.
        new_last.scatter_(-1, top_k_indices, modified_top_k)

        new_logits = x_logits.clone()
        new_logits[:, -1, :] = new_last

        self.latency.append(time.perf_counter() - start)
        return new_logits